{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import imageio\n",
    "# Read the sample dog photo from disk into a NumPy array.\n",
    "dog_image_path = '../data/p1ch4/image-dog/bobby.jpg'\n",
    "img_arr = imageio.imread(dog_image_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(720, 1280, 3)"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Display the array's dimensions (the recorded output is (720, 1280, 3)).\n",
    "img_arr.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Wrap the NumPy image in a tensor.\n",
    "img = torch.from_numpy(img_arr)\n",
    "# Reorder the axes so the last one (size 3 in the recorded shape) comes first.\n",
    "out = img.permute(2, 0, 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 做一个batch\n",
    "\n",
    "batch_size = 3\n",
    "batch = torch.zeros(batch_size,3,256,256,dtype = torch.uint8)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collect the PNG files that will be copied into the batch.\n",
    "import os\n",
    "data_dir ='../data/p1ch4/image-cats/'\n",
    "# os.listdir returns entries in arbitrary order; sort them so that batch row i\n",
    "# always holds the same image across runs and platforms.\n",
    "filenames = sorted(\n",
    "    name for name in os.listdir(data_dir)\n",
    "    if os.path.splitext(name)[-1] == '.png'\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Copy each image file into its own row of the batch.\n",
    "for i, filename in enumerate(filenames):\n",
    "    # Read the next image from disk.\n",
    "    img_arr = imageio.imread(os.path.join(data_dir, filename))\n",
    "    # Move the last axis to the front, then keep only the first three channels.\n",
    "    img_t = torch.from_numpy(img_arr).permute(2, 0, 1)[:3]\n",
    "    batch[i] = img_t"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scale the pixel values into the 0-1 range.\n",
    "# The direct approach: convert to float and divide every pixel by 255.\n",
    "batch = batch.float().div_(255.0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Alternative: compute the mean and standard deviation of the input and rescale it\n",
    "# so that every channel ends up with zero mean and unit standard deviation.\n",
    "\n",
    "n_channels = batch.shape[1]\n",
    "for c in range(n_channels):\n",
    "    channel = batch[:, c]      # channel c across every image in the batch\n",
    "    mean = channel.mean()      # mean over all images and pixels of this channel\n",
    "    std = channel.std()        # standard deviation over the same values\n",
    "    batch[:, c] = (channel - mean) / std"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
